{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, csv\n",
    "from tqdm import tqdm, trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# file paths\n",
    "data_dir = \"/home/ubuntu/bhargav/data/\"\n",
    "dataset = \"yelp\" # amazon / yelp / imagecaption\n",
    "train_0 = os.path.join(data_dir ,\"./{}/sentiment_train_0.txt\".format(dataset))\n",
    "train_1 = os.path.join(data_dir,\"./{}/sentiment_train_1.txt\".format(dataset))\n",
    "test_0 = os.path.join(data_dir,\"./{}/sentiment_test_0.txt\".format(dataset))\n",
    "test_1 = os.path.join(data_dir,\"./{}/sentiment_test_1.txt\".format(dataset))\n",
    "dev_0 = os.path.join(data_dir,\"./{}/sentiment_dev_0.txt\".format(dataset))\n",
    "dev_1 = os.path.join(data_dir,\"./{}/sentiment_dev_1.txt\".format(dataset))\n",
    "reference_0 = os.path.join(data_dir,\"./{}/reference_0.txt\".format(dataset))\n",
    "reference_1 = os.path.join(data_dir,\"./{}/reference_1.txt\".format(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Destination files for the combined classifier data, one per split.\n",
    "# Each template is filled with the dataset name and placed under data_dir.\n",
    "train_out, dev_out, test_out = (\n",
    "    os.path.join(data_dir, template.format(dataset))\n",
    "    for template in (\n",
    "        \"./{}/bert_classifier_training/train.csv\",\n",
    "        \"./{}/bert_classifier_training/dev.csv\",\n",
    "        \"./{}/bert_classifier_training/test.csv\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_classification_file(input_file_path_list, input_file_label_list, output_file_path):\n",
    "    \"\"\"\n",
    "    Create a tab-separated file combining labelled text for BERT classification training.\n",
    "\n",
    "    Each line of every input file is stripped of surrounding whitespace and written\n",
    "    as one row holding the text followed by the label of the file it came from.\n",
    "    Rows are emitted in input-file order.\n",
    "\n",
    "    input_file_path_list : Paths of the input text files\n",
    "    input_file_label_list : Corresponding labels; entry i labels every line of file i\n",
    "    output_file_path : Path of the file to write; an existing file is overwritten and\n",
    "        a missing parent directory is created\n",
    "\n",
    "    Raises IndexError when there are fewer labels than input files.\n",
    "    \"\"\"\n",
    "    # Create the destination folder up front so that open() does not fail\n",
    "    # when the output sub-directory has not been made yet.\n",
    "    out_dir = os.path.dirname(output_file_path)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    # The csv module expects its output file to be opened with an empty newline\n",
    "    # argument; without it, platforms that translate line endings write an\n",
    "    # extra carriage return after every row.\n",
    "    with open(output_file_path, \"w\", newline=\"\") as out_fp:\n",
    "        writer = csv.writer(out_fp, delimiter=\"\\t\")\n",
    "        for i, input_path in enumerate(tqdm(input_file_path_list)):\n",
    "            # The label is constant for a whole file, so look it up once.\n",
    "            label = input_file_label_list[i]\n",
    "            with open(input_path) as in_fp:\n",
    "                for line in in_fp:\n",
    "                    writer.writerow([line.strip(), label])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the train, test and dev files in that order; in every split the\n",
    "# *_0 file receives label 0 and the *_1 file receives label 1.\n",
    "for split_inputs, split_output in (\n",
    "    ([train_0, train_1], train_out),\n",
    "    ([test_0, test_1], test_out),\n",
    "    ([dev_0, dev_1], dev_out),\n",
    "):\n",
    "    create_classification_file(split_inputs, [0, 1], split_output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
